{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c8d01073-09a8-4c68-b93b-aef463738bd0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os, sys, glob, argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset\n",
    "\n",
    "# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "70d2263d-916c-431e-aca7-3e5bcd867f9b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# glob order depends on the filesystem, so sort first and shuffle with a fixed\n",
    "# seed: the train/validation split is then identical on every Restart & Run All.\n",
    "train_path = sorted(glob.glob('./叶片病害识别挑战赛训练集/*/*'))\n",
    "np.random.RandomState(0).shuffle(train_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "46d68b4a-10c8-49f3-a71e-e52c86fe2bae",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3400"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "584f2188-cf45-4d72-abee-4dae0d362926",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiDataset(Dataset):\n",
    "    \"\"\"Leaf-disease image dataset whose labels are inferred from each file path.\n",
    "\n",
    "    Args:\n",
    "        img_path: sequence of image file paths; each path must contain one of the\n",
    "            names in ``CLASS_NAMES``.\n",
    "        transform: optional callable applied to the loaded PIL image.\n",
    "    \"\"\"\n",
    "\n",
    "    # Checked in this order; the position in the tuple is the integer label.\n",
    "    CLASS_NAMES = ('powdery_mildew', 'healthy', 'rust', 'scab')\n",
    "\n",
    "    def __init__(self, img_path, transform=None):\n",
    "        self.img_path = img_path\n",
    "        self.transform = transform\n",
    "\n",
    "    @classmethod\n",
    "    def _label_from_path(cls, path):\n",
    "        \"\"\"Return the index of the first class name that occurs in ``path``.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if no known class name occurs in ``path``.\n",
    "        \"\"\"\n",
    "        for label, name in enumerate(cls.CLASS_NAMES):\n",
    "            if name in path:\n",
    "                return label\n",
    "        raise ValueError(f'Cannot infer class label from path: {path!r}')\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Load sample ``index`` and return ``(image, label)``.\"\"\"\n",
    "        path = self.img_path[index]\n",
    "        # Force three channels so grayscale/RGBA files do not break the\n",
    "        # 3-channel pipeline; the context manager closes the file handle.\n",
    "        with Image.open(path) as raw:\n",
    "            img = raw.convert('RGB')\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        # Explicit int64: class-index targets must not depend on the platform's default int.\n",
    "        return img, torch.tensor(self._label_from_path(path), dtype=torch.long)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "473a6765-dd3d-4f48-8d34-ab4c32c013b4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiNet(nn.Module):\n",
    "    \"\"\"ResNet-50 backbone with its final layer replaced by a new classifier head.\n",
    "\n",
    "    Args:\n",
    "        num_classes: number of output logits (default 4).\n",
    "        pretrained: whether to start from torchvision's pretrained weights (default True).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=4, pretrained=True):\n",
    "        super().__init__()\n",
    "        # Pass the flag by keyword; a bare positional boolean hides what it controls.\n",
    "        backbone = models.resnet50(pretrained=pretrained)\n",
    "        backbone.avgpool = nn.AdaptiveAvgPool2d(1)\n",
    "        # Size the new head from the layer it replaces instead of hard-coding 2048.\n",
    "        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)\n",
    "        # Keep the attribute name `resnet`: state_dict keys are prefixed with it.\n",
    "        self.resnet = backbone\n",
    "\n",
    "    def forward(self, img):\n",
    "        \"\"\"Return the class logits the backbone produces for the batch ``img``.\"\"\"\n",
    "        return self.resnet(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "47539bef-14a7-4881-85e9-0882ff295e6a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train(train_loader, model, criterion, optimizer):\n",
    "    \"\"\"Run one optimisation pass over ``train_loader`` and return the mean batch loss.\n",
    "\n",
    "    The current batch loss is printed on every 100th batch, starting with the first.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    running_total = 0.0\n",
    "    for step, (images, labels) in enumerate(train_loader):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        # Forward pass and loss for this batch.\n",
    "        batch_loss = criterion(model(images), labels)\n",
    "\n",
    "        # Backward pass followed by the parameter update.\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            print('Train loss', batch_loss.item())\n",
    "\n",
    "        running_total += batch_loss.item()\n",
    "\n",
    "    return running_total / len(train_loader)\n",
    "            \n",
    "def validate(val_loader, model, criterion):\n",
    "    \"\"\"Return the fraction of samples in ``val_loader.dataset`` classified correctly.\n",
    "\n",
    "    ``criterion`` is accepted so existing call sites keep working, but it is not\n",
    "    used: only accuracy is reported, so no per-batch loss is computed.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, target in val_loader:\n",
    "            images = images.to(device)\n",
    "            target = target.to(device)\n",
    "\n",
    "            # The predicted class is the index of the largest logit in each row.\n",
    "            correct += (model(images).argmax(1) == target).sum().item()\n",
    "\n",
    "    return correct / len(val_loader.dataset)\n",
    "\n",
    "def predict(test_loader, model, criterion=None):\n",
    "    \"\"\"Return the stacked raw model outputs for every batch in ``test_loader``.\n",
    "\n",
    "    Args:\n",
    "        test_loader: iterable of ``(input, target)`` batches; the targets are ignored.\n",
    "        model: network to evaluate.\n",
    "        criterion: unused; kept (now optional) so three-argument calls still work.\n",
    "\n",
    "    Returns:\n",
    "        ``numpy.ndarray`` built by vertically stacking the per-batch outputs.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    test_pred = []\n",
    "    with torch.no_grad():\n",
    "        for images, _ in test_loader:\n",
    "            # Only the inputs are needed, so the targets stay where they are.\n",
    "            output = model(images.to(device))\n",
    "            test_pred.append(output.cpu().numpy())\n",
    "\n",
    "    return np.vstack(test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "c14b410f-de94-4f52-be5b-1377c34bda4a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "VAL_SIZE = 500  # number of (already shuffled) paths held out for validation\n",
    "\n",
    "# Channel statistics shared by both splits.\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                 std=[0.229, 0.224, 0.225])\n",
    "\n",
    "# Training split: random crops and flips act as data augmentation.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(\n",
    "        train_path[:-VAL_SIZE],\n",
    "        transforms.Compose([\n",
    "            transforms.Resize(256),\n",
    "            transforms.RandomResizedCrop(224),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.RandomVerticalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "            normalize,\n",
    "        ]),\n",
    "    ),\n",
    "    batch_size=15, shuffle=True, num_workers=4, pin_memory=False,\n",
    ")\n",
    "\n",
    "# Validation split: preprocessing must be deterministic, otherwise the reported\n",
    "# accuracy changes between evaluations of the same weights. A fixed center crop\n",
    "# replaces the random crop/flip used for training.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(\n",
    "        train_path[-VAL_SIZE:],\n",
    "        transforms.Compose([\n",
    "            transforms.Resize(256),\n",
    "            transforms.CenterCrop(224),\n",
    "            transforms.ToTensor(),\n",
    "            normalize,\n",
    "        ]),\n",
    "    ),\n",
    "    batch_size=30, shuffle=False, num_workers=1, pin_memory=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = XunFeiNet().to(device)\n",
    "# Follow `device` instead of calling .cuda(), which raises on CPU-only machines.\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss 1.5940135717391968\n",
      "Train loss 1.1077581644058228\n",
      "1.1705250399014384 0.716\n",
      "Train loss 0.9211582541465759\n",
      "Train loss 0.7091454267501831\n",
      "0.8290592757696958 0.83\n",
      "Train loss 0.6219263076782227\n",
      "Train loss 0.47454729676246643\n",
      "0.6132681077903079 0.828\n",
      "Train loss 0.45259857177734375\n",
      "Train loss 0.5992544293403625\n",
      "0.5223359434106916 0.864\n",
      "Train loss 0.560684084892273\n",
      "Train loss 0.24091745913028717\n",
      "0.4466889486792161 0.87\n",
      "Train loss 0.471619188785553\n",
      "Train loss 0.3848036527633667\n",
      "0.3995219959502982 0.864\n",
      "Train loss 0.3464297354221344\n",
      "Train loss 0.33798450231552124\n",
      "0.3911478723754588 0.886\n",
      "Train loss 0.3078201711177826\n",
      "Train loss 0.36544033885002136\n",
      "0.3732020866855518 0.908\n",
      "Train loss 0.8368951082229614\n",
      "Train loss 0.23293063044548035\n",
      "0.35171481525314224 0.898\n",
      "Train loss 0.245680570602417\n",
      "Train loss 0.29640594124794006\n",
      "0.3521484049785997 0.916\n"
     ]
    }
   ],
   "source": [
    "# Ten epochs: after each one, report the mean training loss and validation accuracy.\n",
    "for epoch in range(10):\n",
    "    train_loss = train(train_loader, model, criterion, optimizer)\n",
    "    val_acc = validate(val_loader, model, criterion)\n",
    "    print(train_loss, val_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "4c342ae2-1004-41af-bbe5-0b76f4d46a41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the weights to host memory before serialising the state dict.\n",
    "model = model.cpu()\n",
    "torch.save(model.state_dict(), 'model.pt')"
   ]
  },
  {
   "cell_type": "raw",
   "id": "f9b8c959-cc6c-4389-9f97-1fcced231644",
   "metadata": {},
   "source": [
    "文件夹可以组织为如下格式：\n",
    "leafs-test/\n",
    "leafs-test/model/\n",
    "leafs-test/model/model.pt\n",
    "leafs-test/.ipynb_checkpoints/\n",
    "leafs-test/.ipynb_checkpoints/run-checkpoint.py\n",
    "leafs-test/run.py\n",
    "\n",
    "tar -cvzf leafs-test.tar.gz leafs-test/\n",
    "s3cmd put leafs-test.tar.gz s3://ai-competition/你的URL/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81aab234-cff6-4866-99c7-ef4e13080700",
   "metadata": {},
   "source": [
    "run.py 内容如下：\n",
    "\n",
    "```python\n",
    "import os, sys, glob, argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset\n",
    "\n",
    "# Inference is pinned to the CPU (no GPU check is performed).\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "class XunFeiDataset(Dataset):\n",
    "    \"\"\"Image dataset for inference; a label is inferred from the path when possible.\n",
    "\n",
    "    Paths that contain no known class name (such as unlabeled test images) get the\n",
    "    placeholder label 0.\n",
    "\n",
    "    Args:\n",
    "        img_path: sequence of image file paths.\n",
    "        transform: optional callable applied to the loaded PIL image.\n",
    "    \"\"\"\n",
    "\n",
    "    # Checked in this order; the position in the tuple is the integer label.\n",
    "    CLASS_NAMES = ('powdery_mildew', 'healthy', 'rust', 'scab')\n",
    "\n",
    "    def __init__(self, img_path, transform=None):\n",
    "        self.img_path = img_path\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        path = self.img_path[index]\n",
    "        # Force three channels so grayscale/RGBA files do not break the\n",
    "        # 3-channel pipeline; the context manager closes the file handle.\n",
    "        with Image.open(path) as raw:\n",
    "            img = raw.convert('RGB')\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        # First matching class name wins; unmatched paths fall back to 0.\n",
    "        label = next((i for i, name in enumerate(self.CLASS_NAMES) if name in path), 0)\n",
    "        return img, torch.tensor(label, dtype=torch.long)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)\n",
    "    \n",
    "class XunFeiNet(nn.Module):\n",
    "    \"\"\"ResNet-50 backbone with its final layer replaced by a new classifier head.\n",
    "\n",
    "    Args:\n",
    "        num_classes: number of output logits (default 4).\n",
    "        pretrained: whether to fetch torchvision's pretrained weights (default False,\n",
    "            because the trained weights are loaded from a checkpoint afterwards).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=4, pretrained=False):\n",
    "        super().__init__()\n",
    "        # Pass the flag by keyword; a bare positional boolean hides what it controls.\n",
    "        backbone = models.resnet50(pretrained=pretrained)\n",
    "        backbone.avgpool = nn.AdaptiveAvgPool2d(1)\n",
    "        # Size the new head from the layer it replaces instead of hard-coding 2048.\n",
    "        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)\n",
    "        # Keep the attribute name: checkpoint keys are prefixed with it.\n",
    "        self.resnet = backbone\n",
    "\n",
    "    def forward(self, img):\n",
    "        \"\"\"Return the class logits the backbone produces for the batch img.\"\"\"\n",
    "        return self.resnet(img)\n",
    "    \n",
    "def predict(test_loader, model):\n",
    "    \"\"\"Return the vertically stacked raw outputs of model over test_loader.\n",
    "\n",
    "    Each batch is an (input, target) pair; the target is ignored.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    test_pred = []\n",
    "    with torch.no_grad():\n",
    "        for images, _ in test_loader:\n",
    "            # Only the inputs are needed, so the targets stay where they are.\n",
    "            output = model(images.to(device))\n",
    "            test_pred.append(output.cpu().numpy())\n",
    "\n",
    "    return np.vstack(test_pred)\n",
    "\n",
    "\n",
    "# Sorted so the row order of the submission file is stable.\n",
    "test_path = sorted(glob.glob('/work/data/leafs-test-dataset/*'))\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(\n",
    "        test_path,\n",
    "        transforms.Compose([\n",
    "            # Resize with a single int only fixes the shorter side, so images with\n",
    "            # different aspect ratios could not be stacked into one batch; the\n",
    "            # center crop gives every sample the same 224x224 shape.\n",
    "            transforms.Resize(256),\n",
    "            transforms.CenterCrop(224),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                 std=[0.229, 0.224, 0.225]),\n",
    "        ]),\n",
    "    ),\n",
    "    batch_size=30, shuffle=False, num_workers=1, pin_memory=False,\n",
    ")\n",
    "model = XunFeiNet()\n",
    "# map_location makes loading independent of the device the checkpoint was saved from.\n",
    "model.load_state_dict(torch.load('./model/model.pt', map_location=device))\n",
    "model = model.to(device)\n",
    "\n",
    "test_pred = predict(test_loader, model)\n",
    "test_pred = test_pred.argmax(1)\n",
    "class_names = np.array(['powdery_mildew', 'healthy', 'rust', 'scab'])\n",
    "test_pred = class_names[test_pred]\n",
    "\n",
    "pd.DataFrame({\n",
    "    'uuid': [os.path.basename(x) for x in test_path],\n",
    "    'label': test_pred\n",
    "}).to_csv('/work/output/result.csv', index=None)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea00dec7-f842-459a-b2c6-e108dc9b522b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12b893d0-d22e-401a-b39a-4810c7ec7e6a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
